"""
download_ocean_nasa.py
----------------------
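Downloads NASA Aqua MODIS Level-3 Mapped (L3m) daily SST and ocean color
variables at 4km resolution for the Louisiana coastal bounding box.
Note: the download filter currently selects only ".nc" files whose names
contain both "4km" and "SST".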

Designed specifically to preserve MNAR (Missing Not At Random) cloud gaps
as NaN rather than imputing them — the missingness structure is the subject
of analysis and must not be altered during acquisition.

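Output: one CSV per day saved to the specified output directory,
named louisiana_YYYY-MM-DD.csv, plus the combined long-format file
Aqua_MODIS_MNAR_Long.csv and, when an SST column is present, the
wide-format file Aqua_MODIS_MNAR_Wide_SST.csv.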

Source: NASA Ocean Biology Processing Group (OBPG) via
        https://oceandata.sci.gsfc.nasa.gov/
        
Usage:
    python download_ocean_nasa.py
    
    Modify the appkey variable at the bottom with your NASA EarthData
    application key if required.
"""

import os
import shutil
import tempfile
import re
import requests
import time
import concurrent.futures
from bs4 import BeautifulSoup
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import xarray as xr


# --- Bounding Box: Louisiana Shellfish Harvest Areas ---
# Covers all 28 designated harvest areas along the Louisiana coast
# Edges are compared (inclusively) against the latitude/longitude coordinates
# of each grid pixel; longitudes use the negative-west convention, matching
# the [-180, 180] range that longitudes are normalized to before clipping.
LAT_N = 30.711210  # northern edge
LAT_S = 28.425522  # southern edge
LON_W = -95.575013  # western edge
LON_E = -88.630602  # eastern edge

# Coordinate rounding precision (3 decimal places ~ 0.1 km, sufficient for 4km grid)
# Applied to the latitude/longitude columns before each per-file CSV is written.
LATLON_ROUND = 3


def _parse_date(day):
    """
    Format a date three ways for URL building and file naming.

    ``day`` is either a ``YYYY-MM-DD`` string, which is parsed first, or an
    object that already provides ``strftime`` (e.g. ``datetime``/``date``).

    Returns a tuple ``(DD-Mon-YYYY, YYYY, YYYY-MM-DD)``.
    """
    if isinstance(day, str):
        moment = datetime.strptime(day, "%Y-%m-%d")
    else:
        moment = day

    layouts = ("%d-%b-%Y", "%Y", "%Y-%m-%d")
    return tuple(moment.strftime(layout) for layout in layouts)


def _fetch_aqua_modis_l3(day_str, nc_dir, appkey=""):
    """
    Downloads all available 4km SST L3m NetCDF files for a given day
    from the NASA OBPG Aqua-MODIS archive.

    The day's directory listing is fetched and parsed; every link whose
    text ends in ".nc" and contains both "4km" and "SST" is downloaded
    into ``nc_dir`` (created if needed). Files already present in
    ``nc_dir`` are not downloaded again.

    Parameters
    ----------
    day_str : str or datetime-like  Day to fetch (anything _parse_date accepts)
    nc_dir  : str  Directory that receives the NetCDF files
    appkey  : str  Optional key appended as an ``appkey`` query parameter

    Returns a list of local file paths for the downloaded files (empty
    when the listing has no matching links).

    Raises ``requests.HTTPError`` (via ``raise_for_status``) when the
    listing or a file request returns an error status.
    """
    os.makedirs(nc_dir, exist_ok=True)
    dstr_human, year, _ = _parse_date(day_str)

    base_url = (
        f"https://oceandata.sci.gsfc.nasa.gov/directdataaccess/"
        f"Level-3%20Mapped/Aqua-MODIS/{year}/{dstr_human}/"
    )
    headers = {"User-Agent": "python"}

    response = requests.get(base_url, timeout=30, headers=headers)
    response.raise_for_status()

    soup = BeautifulSoup(response.text, "html.parser")
    # Prefer links inside the listing table; fall back to every link on the page.
    anchors = soup.select("tbody#tbody a[href]") or soup.select("a[href]")

    downloaded = []
    seen = set()
    for anchor in anchors:
        label = anchor.get_text(strip=True)

        # Filter for 4km SST NetCDF files only
        if not (label.endswith(".nc") and "4km" in label and "SST" in label):
            continue

        # The link text comes from a remote page: keep only its final path
        # component so it can never point outside nc_dir.
        filename = os.path.basename(label) or anchor.get("href", "").split("/")[-1]
        local_path = os.path.join(nc_dir, filename)
        if local_path in seen:
            # The same file linked twice in the listing; one copy is enough.
            continue
        seen.add(local_path)

        file_url = requests.compat.urljoin(base_url, anchor["href"])
        if appkey and "appkey=" not in file_url:
            separator = "&" if "?" in file_url else "?"
            file_url += f"{separator}appkey={appkey}"

        if not os.path.exists(local_path):
            # Stream into a ".part" file and rename only after the transfer
            # finished, so an interrupted download can never be mistaken for
            # a complete file by the existence check above.
            part_path = local_path + ".part"
            try:
                with requests.get(
                    file_url, stream=True, timeout=180, headers=headers
                ) as dl:
                    dl.raise_for_status()
                    with open(part_path, "wb") as f:
                        for chunk in dl.iter_content(chunk_size=1 << 15):
                            if chunk:
                                f.write(chunk)
                os.replace(part_path, local_path)
            finally:
                # Only still present when the download or rename failed.
                if os.path.exists(part_path):
                    os.remove(part_path)

        downloaded.append(local_path)

    return downloaded


def _select_primary_variable(ds):
    """
    Return the name of the data variable holding the most elements.

    If several variables share the maximum size, the first one in
    ``ds.data_vars`` iteration order wins.

    Raises ValueError when the dataset has no data variables.
    """
    if not ds.data_vars:
        raise ValueError("No data variables found in NetCDF file.")

    winner = None
    winner_size = None
    for name in ds.data_vars:
        n_elements = ds[name].size
        # Strict comparison keeps the earliest variable on ties.
        if winner_size is None or n_elements > winner_size:
            winner, winner_size = name, n_elements
    return winner


def _identify_spatial_dims(da):
    """
    Work out which dimensions of a DataArray are latitude and longitude.

    Dimension names are first matched case-insensitively against known
    latitude/longitude spellings. If either one is not found by name, the
    two largest dimensions (ignoring time/band-like ones) are used instead,
    with the larger taken as longitude.

    Returns ``(lat_dim, lon_dim)``; raises ValueError when fewer than two
    candidate dimensions remain.
    """
    known_lat = ("lat", "latitude", "y", "rows")
    known_lon = ("lon", "longitude", "x", "cols", "columns")

    lat_dim = None
    lon_dim = None
    for dim in da.dims:
        lowered = dim.lower()
        if lat_dim is None and lowered in known_lat:
            lat_dim = dim
        if lon_dim is None and lowered in known_lon:
            lon_dim = dim

    if lat_dim and lon_dim:
        return lat_dim, lon_dim

    # Fallback: assume two largest non-temporal dimensions are lat/lon
    non_spatial = ("time", "band", "nband")
    ranked = [
        (da.sizes[dim], dim) for dim in da.dims if dim.lower() not in non_spatial
    ]
    ranked.sort(reverse=True)

    if len(ranked) < 2:
        raise ValueError("Could not identify spatial dimensions in file.")

    largest = ranked[0][1]
    runner_up = ranked[1][1]
    return runner_up, largest


def _extract_coordinates(ds, da, lat_dim, lon_dim):
    """
    Extract latitude and longitude arrays from dataset coordinates
    or reconstruct from global attributes if not present.

    Each axis is looked up in ``ds.coords`` under its dimension name first
    and under "lat"/"lon" second. An axis found in neither place is rebuilt
    as an evenly spaced array spanning the extent given by the dataset's
    global attributes (matched case-insensitively), defaulting to the whole
    globe, with one value per element of that dimension in ``da``.

    Returns ``(lat, lon)`` as NumPy arrays.
    """
    def _lookup(dim_name, fallback_name):
        # The None check must happen before np.asarray: np.asarray(None) is a
        # 0-d object array, which would hide the fact that nothing was found.
        coord = ds.coords.get(dim_name)
        if coord is None:
            coord = ds.coords.get(fallback_name)
        return None if coord is None else np.asarray(coord)

    lat = _lookup(lat_dim, "lat")
    lon = _lookup(lon_dim, "lon")

    if lat is None or lon is None:
        attrs = {k.lower(): v for k, v in ds.attrs.items()}

        # NOTE(review): linspace assumes values run south->north and
        # west->east along the dimension and that the attribute extents line
        # up with the first/last pixel — confirm against the file layout.
        if lat is None:
            south = float(attrs.get("southernmost_latitude", attrs.get("geospatial_lat_min", -90)))
            north = float(attrs.get("northernmost_latitude", attrs.get("geospatial_lat_max", 90)))
            lat = np.linspace(south, north, da.sizes[lat_dim])
        if lon is None:
            west = float(attrs.get("westernmost_longitude", attrs.get("geospatial_lon_min", -180)))
            east = float(attrs.get("easternmost_longitude", attrs.get("geospatial_lon_max", 180)))
            lon = np.linspace(west, east, da.sizes[lon_dim])

    return lat, lon


def _nc_to_bbox_csv(nc_path, csv_dir):
    """
    Extract the Louisiana bounding box from a single NetCDF file and
    write to CSV. NaN values (cloud-masked pixels) are preserved as-is.

    The CSV has columns ``latitude``, ``longitude`` and one column named
    after the selected data variable, and is written to ``csv_dir`` using
    the NetCDF file's base name with a ``.csv`` extension.

    Returns the CSV path, or None when no grid point falls inside the
    bounding box.

    Raises ValueError if the selected variable has dimensions other than
    the two spatial ones after squeezing.
    """
    os.makedirs(csv_dir, exist_ok=True)

    with xr.open_dataset(nc_path) as ds:
        var = _select_primary_variable(ds)
        da = ds[var].squeeze(drop=True)

        lat_dim, lon_dim = _identify_spatial_dims(da)
        lat, lon = _extract_coordinates(ds, da, lat_dim, lon_dim)

        # Normalize longitudes to [-180, 180]
        lon = np.where(lon > 180.0, lon - 360.0, lon)

        # Clip to bounding box
        lat_idx = np.flatnonzero((lat >= LAT_S) & (lat <= LAT_N))
        lon_idx = np.flatnonzero((lon >= LON_W) & (lon <= LON_E))

        if lat_idx.size == 0 or lon_idx.size == 0:
            return None

        # Select the window positionally *before* reading values so that only
        # the bounding box is converted to float, not the whole grid. The
        # explicit transpose guarantees (lat, lon) axis order for the ravel
        # below regardless of how the file orders its dimensions.
        window = da.isel({lat_dim: lat_idx, lon_dim: lon_idx})
        window = window.transpose(lat_dim, lon_dim)

        # Apply fill value masking
        for fill_key in ("_FillValue", "fill_value", "missing_value"):
            if fill_key in da.attrs:
                window = window.where(window != da.attrs[fill_key])

        sub = window.values.astype(float)
        sub_lat = lat[lat_idx]
        sub_lon = lon[lon_idx]

        # indexing="xy" yields grids shaped (n_lat, n_lon), matching `sub`.
        LON_grid, LAT_grid = np.meshgrid(sub_lon, sub_lat, indexing="xy")

        df = pd.DataFrame({
            "latitude": LAT_grid.ravel(),
            "longitude": LON_grid.ravel(),
            var: sub.ravel()
        })

        df["latitude"] = df["latitude"].round(LATLON_ROUND)
        df["longitude"] = df["longitude"].round(LATLON_ROUND)

        # Aggregate any duplicate grid points created by rounding
        df = df.groupby(["latitude", "longitude"], as_index=False)[var].mean()

        # Apply precise bounding box filter
        df = df[
            (df["latitude"] >= LAT_S) & (df["latitude"] <= LAT_N) &
            (df["longitude"] >= LON_W) & (df["longitude"] <= LON_E)
        ]

        if df.empty:
            return None

        stem = os.path.splitext(os.path.basename(nc_path))[0]
        out_path = os.path.join(csv_dir, f"{stem}.csv")
        df.to_csv(out_path, index=False)
        return out_path


def _merge_daily_csvs(csv_dir, output_path, date_str):
    """
    Merges all per-variable CSV files for a single day into one wide-format
    file, outer-joining on lat/lon to preserve all observed ocean pixels.

    Every ``*.csv`` in ``csv_dir`` contributes its first non-coordinate
    column, named after the file's stem (with a ``__N`` suffix on
    collisions). Known product names are then mapped to the column names
    used downstream, a ``Date`` column holding ``date_str`` is added, and
    rows with no value in any feature column are dropped.

    Returns ``output_path`` after writing the merged CSV there.

    Raises RuntimeError when no CSV yields any usable data.
    """
    from glob import glob

    used_names = set()
    merged = None

    for path in sorted(glob(os.path.join(csv_dir, "*.csv"))):
        df = pd.read_csv(path, low_memory=False)
        if df.empty:
            continue

        # Accept either spelling of the coordinate columns, in any case.
        col_map = {c.lower(): c for c in df.columns}
        lat_col = col_map.get("latitude") or col_map.get("lat")
        lon_col = col_map.get("longitude") or col_map.get("lon")
        other_cols = [c for c in df.columns if c not in {lat_col, lon_col}]

        if not lat_col or not lon_col or not other_cols:
            continue

        feature = other_cols[0]
        df = df[[lat_col, lon_col, feature]].rename(
            columns={lat_col: "latitude", lon_col: "longitude"}
        )
        df[feature] = pd.to_numeric(df[feature], errors="coerce")
        df = df.groupby(["latitude", "longitude"], as_index=False)[feature].mean()

        # Name the feature column after the file so columns stay unique.
        stem = re.sub(r"\s+", "_", os.path.splitext(os.path.basename(path))[0])
        col_name, suffix = stem, 2
        while col_name in used_names:
            col_name = f"{stem}__{suffix}"
            suffix += 1
        used_names.add(col_name)

        df = df.rename(columns={feature: col_name})

        if merged is None:
            merged = df
        else:
            merged = pd.merge(merged, df, on=["latitude", "longitude"], how="outer")

    if merged is None or merged.empty:
        raise RuntimeError(f"No valid data found in bounding box for {date_str}.")

    merged = merged.sort_values(["latitude", "longitude"], kind="mergesort")

    # Standardize column names to match downstream analysis conventions
    rename_map = {
        "latitude": "Lat", "longitude": "Lon",
        "sst": "SST", "chlor_a": "chlor_a", "Kd_490": "Kd_490",
        "aot_869": "aot_869", "angstrom": "angstrom", "par": "par",
        "pic": "pic", "poc": "poc", "ipar": "ipar", "nflh": "nflh",
    }

    def _standard_name(col):
        # Dotted names are file stems: the product token is read from the
        # third-from-last dot-separated field. Stems with fewer than three
        # fields have no such token and are looked up (and kept) whole
        # instead of raising IndexError.
        # NOTE(review): position -3 assumes stems end in
        # "<product>.<x>.<y>" — confirm against the archive's file names.
        if "." not in col:
            return rename_map.get(col, col)
        parts = col.split(".")
        key = parts[-3] if len(parts) >= 3 else col
        return rename_map.get(key, col)

    merged = merged.rename(columns=_standard_name)

    # Add date column for temporal indexing
    merged["Date"] = date_str

    # Remove rows where all non-coordinate columns are NaN (land pixels)
    coord_cols = {"Lat", "Lon", "Date"}
    feature_cols = [c for c in merged.columns if c not in coord_cols]
    merged = merged.dropna(subset=feature_cols, how="all")

    merged.to_csv(output_path, index=False)
    return output_path


def build_daily_csv(day, outdir=".", appkey=""):
    """
    Fetch one day of Aqua MODIS L3m data and write it as
    louisiana_YYYY-MM-DD.csv inside ``outdir``.

    NetCDF files and intermediate CSVs live in a temporary directory that
    is removed (best effort) before returning.

    Returns the path of the daily CSV, or None when nothing was downloaded,
    no file had data inside the bounding box, or an error occurred (the
    error is printed rather than raised).
    """
    os.makedirs(outdir, exist_ok=True)
    date_iso = _parse_date(day)[2]
    final_path = os.path.join(outdir, f"louisiana_{date_iso}.csv")

    scratch = tempfile.mkdtemp(prefix="aqua_modis_")
    nc_dir = os.path.join(scratch, "nc")
    csv_dir = os.path.join(scratch, "csv")

    try:
        nc_files = _fetch_aqua_modis_l3(day, nc_dir, appkey)
        if not nc_files:
            return None

        # Convert every file (no short-circuit) and remember each outcome.
        outcomes = [_nc_to_bbox_csv(nc_path, csv_dir) for nc_path in nc_files]
        if all(outcome is None for outcome in outcomes):
            return None

        _merge_daily_csvs(csv_dir, final_path, date_iso)
        return final_path

    except Exception as e:
        print(f"  Error on {day}: {e}")
        return None

    finally:
        # Cleanup is best effort: a failure here must not mask the result.
        try:
            shutil.rmtree(scratch)
        except Exception:
            pass


def _pad_to_full_grid(master, days):
    """Return ``master`` expanded to one row per observed (Lat, Lon) per day."""
    coords = (
        master[["Lat", "Lon"]]
        .drop_duplicates()
        .sort_values(["Lat", "Lon"], kind="mergesort")
        .reset_index(drop=True)
    )
    n_days = len(days)
    # Pair every observed coordinate with every day directly, rather than
    # building the full Lat x Lon x day product and filtering it afterwards.
    grid = pd.DataFrame({
        "Lat": np.repeat(coords["Lat"].to_numpy(), n_days),
        "Lon": np.repeat(coords["Lon"].to_numpy(), n_days),
        "Date": np.tile(np.asarray(days, dtype=object), len(coords)),
    })
    # Left join: grid rows without an observation keep NaN in every feature.
    return pd.merge(grid, master, on=["Lat", "Lon", "Date"], how="left")


def run_date_range(start_date, end_date, outdir=".", appkey="", max_workers=10):
    """
    Downloads Aqua MODIS L3m data for all days in [start_date, end_date],
    processes each to a per-day CSV, then assembles a combined MNAR-padded
    long-format and wide-format dataset.

    The long-format file (Aqua_MODIS_MNAR_Long.csv) has one row per observed
    grid point per requested day, sorted by Lat, Lon, Date; days without an
    observation for a point carry NaN. The wide-format file
    (Aqua_MODIS_MNAR_Wide_SST.csv) is written only when an ``SST`` column
    exists. Nothing is assembled if no day produced a CSV.

    Parameters
    ----------
    start_date : str  (YYYY-MM-DD)
    end_date   : str  (YYYY-MM-DD)
    outdir     : str  Output directory for daily CSVs and merged outputs
    appkey     : str  NASA EarthData application key (optional)
    max_workers: int  Concurrent download threads

    Returns None.
    """
    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end   = datetime.strptime(end_date, "%Y-%m-%d").date()
    days  = [(start + timedelta(days=i)).strftime("%Y-%m-%d")
             for i in range((end - start).days + 1)]

    os.makedirs(outdir, exist_ok=True)
    print(f"Downloading {len(days)} days of Aqua MODIS L3m SST "
          f"({start_date} to {end_date})...")

    def process_day(d):
        try:
            return d, build_daily_csv(d, outdir=outdir, appkey=appkey)
        except Exception as e:
            # Keep the batch going, but do not hide why this day failed.
            print(f"  Error on {d}: {e}")
            return d, None

    # NOTE(review): a day that failed for technical reasons is reported the
    # same way as a day without data and ends up as NaN rows in the padded
    # grid — check the printed errors before treating those gaps as cloud.
    success_paths = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_day, d): d for d in days}
        for future in concurrent.futures.as_completed(futures):
            d, path = future.result()
            if path:
                print(f"  [OK]      {d}")
                success_paths.append(path)
            else:
                print(f"  [MISSING] {d}  (cloud-covered or no data — NaN preserved)")

    if not success_paths:
        print("No files downloaded successfully.")
        return

    # Threads finish in arbitrary order; sort so the output is reproducible.
    success_paths.sort()

    print(f"\nAssembling MNAR-padded dataset from {len(success_paths)} files...")

    frames = [pd.read_csv(p) for p in success_paths]
    master = pd.concat(frames, ignore_index=True)

    # Reconstruct full spatial-temporal grid to preserve MNAR structure.
    # Unobserved (cloud-masked) days appear as NaN rows rather than being absent.
    master = _pad_to_full_grid(master, days)

    long_path = os.path.join(outdir, "Aqua_MODIS_MNAR_Long.csv")
    master.to_csv(long_path, index=False)
    print(f"\nLong-format output: {long_path}")

    has_sst = "SST" in master.columns
    summary = f"  Rows: {len(master):,}  |  Days: {master['Date'].nunique()}"
    if has_sst:
        # Only report SST missingness when the column exists; reading it
        # unconditionally would raise KeyError after the long file is written.
        summary += f"  |  Missing: {master['SST'].isnull().mean()*100:.1f}%"
    print(summary)

    if has_sst:
        wide = master.pivot(index=["Lat", "Lon"], columns="Date", values="SST").reset_index()
        wide.columns.name = None
        wide_path = os.path.join(outdir, "Aqua_MODIS_MNAR_Wide_SST.csv")
        wide.to_csv(wide_path, index=False)
        print(f"Wide-format SST output: {wide_path}")
        print(f"  Grid points: {wide.shape[0]:,}  |  Days: {wide.shape[1]-2}")


if __name__ == "__main__":
    # Script entry point: download the fixed 2022-2024 range into a "data"
    # folder located next to this file.
    appkey = ""  # Add your NASA EarthData application key here if required

    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, "data", "ocean_l3_MNAR_3YR")

    run_date_range(
        "2022-01-01",
        "2024-12-31",
        outdir=output_dir,
        appkey=appkey,
        max_workers=10,
    )